{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "images.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.6"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "mt9dL5dIir8X"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "ufPx7EiCiqgR",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ucMoYase6URl"
      },
      "source": [
        "# 用 tf.data 加载图片"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_Wwu5SXZmEkB"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/load_data/images\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" />在 tensorflow.google.cn 上查看</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/zh-cn/tutorials/load_data/images.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" />在 Google Colab 运行</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/zh-cn/tutorials/load_data/images.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" />在 Github 上查看源代码</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/zh-cn/tutorials/load_data/images.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" />下载此 notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GEe3i16tQPjo",
        "colab_type": "text"
      },
      "source": [
        "Note: 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为， 所以无法保证它们是最准确的，并且反映了最新的\n",
        "[官方英文文档](https://www.tensorflow.org/?hl=en)。如果您有改进此翻译的建议， 请提交 pull request 到\n",
        "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub 仓库。要志愿地撰写或者审核译文，请加入\n",
        "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Oxw4WahM7DU9"
      },
      "source": [
        "本教程提供一个如何使用 `tf.data` 加载图片的简单例子。\n",
        "\n",
        "本例中使用的数据集分布在图片文件夹中，一个文件夹含有一类图片。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hoQQiZDB6URn"
      },
      "source": [
        "## 配置"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "QGXxBuPyKJw1",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "try:\n",
        "  # Colab only\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "    pass\n",
        "import tensorflow as tf"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "KT6CcaqgQewg",
        "colab": {}
      },
      "source": [
        "AUTOTUNE = tf.data.experimental.AUTOTUNE"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rxndJHNC8YPM"
      },
      "source": [
        "## 下载并检查数据集"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wO0InzL66URu"
      },
      "source": [
        "### 检索图片\n",
        "\n",
        "在你开始任何训练之前，你将需要一组图片来教会网络你想要训练的新类别。你已经创建了一个文件夹，存储了最初使用的拥有创作共用许可的花卉照片。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "rN-Pc6Zd6awg",
        "colab": {}
      },
      "source": [
        "import pathlib\n",
        "data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n",
        "                                         fname='flower_photos', untar=True)\n",
        "data_root = pathlib.Path(data_root_orig)\n",
        "print(data_root)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rFkFK74oO--g"
      },
      "source": [
        "下载了 218 MB 之后，你现在应该有花卉照片副本："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "7onR_lWE7Njj",
        "colab": {}
      },
      "source": [
        "for item in data_root.iterdir():\n",
        "  print(item)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "4yYX3ZRqGOuq",
        "colab": {}
      },
      "source": [
        "import random\n",
        "all_image_paths = list(data_root.glob('*/*'))\n",
        "all_image_paths = [str(path) for path in all_image_paths]\n",
        "random.shuffle(all_image_paths)\n",
        "\n",
        "image_count = len(all_image_paths)\n",
        "image_count"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "t_BbYnLjbltQ",
        "colab": {}
      },
      "source": [
        "all_image_paths[:10]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vkM-IpB-6URx"
      },
      "source": [
        "### 检查图片\n",
        "现在让我们快速浏览几张图片，这样你知道你在处理什么："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wNGateQJ6UR1",
        "colab": {}
      },
      "source": [
        "import os\n",
        "attributions = (data_root/\"LICENSE.txt\").open(encoding='utf-8').readlines()[4:]\n",
        "attributions = [line.split(' CC-BY') for line in attributions]\n",
        "attributions = dict(attributions)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "jgowG2xu88Io",
        "colab": {}
      },
      "source": [
        "import IPython.display as display\n",
        "\n",
        "def caption_image(image_path):\n",
        "    image_rel = pathlib.Path(image_path).relative_to(data_root)\n",
        "    return \"Image (CC BY 2.0) \" + ' - '.join(attributions[str(image_rel)].split(' - ')[:-1])\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "YIjLi-nX0txI",
        "colab": {}
      },
      "source": [
        "for n in range(3):\n",
        "  image_path = random.choice(all_image_paths)\n",
        "  display.display(display.Image(image_path))\n",
        "  print(caption_image(image_path))\n",
        "  print()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OaNOr-co3WKk"
      },
      "source": [
        "### 确定每张图片的标签"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-weOQpDw2Jnu"
      },
      "source": [
        "列出可用的标签："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ssUZ7Qh96UR3",
        "colab": {}
      },
      "source": [
        "label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())\n",
        "label_names"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9l_JEBql2OzS"
      },
      "source": [
        "为每个标签分配索引："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Y8pCV46CzPlp",
        "colab": {}
      },
      "source": [
        "label_to_index = dict((name, index) for index, name in enumerate(label_names))\n",
        "label_to_index"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VkXsHg162T9F"
      },
      "source": [
        "创建一个列表，包含每个文件的标签索引："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "q62i1RBP4Q02",
        "colab": {}
      },
      "source": [
        "all_image_labels = [label_to_index[pathlib.Path(path).parent.name]\n",
        "                    for path in all_image_paths]\n",
        "\n",
        "print(\"First 10 labels indices: \", all_image_labels[:10])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "i5L09icm9iph"
      },
      "source": [
        "### 加载和格式化图片"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "SbqqRUS79ooq"
      },
      "source": [
        "TensorFlow 包含加载和处理图片时你需要的所有工具："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "jQZdySHvksOu",
        "colab": {}
      },
      "source": [
        "img_path = all_image_paths[0]\n",
        "img_path"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2t2h2XCcmK1Y"
      },
      "source": [
        "以下是原始数据："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "LJfkyC_Qkt7A",
        "colab": {}
      },
      "source": [
        "img_raw = tf.io.read_file(img_path)\n",
        "print(repr(img_raw)[:100]+\"...\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "opN8AVc8mSbz"
      },
      "source": [
        "将它解码为图像 tensor（张量）："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Tm0tdrlfk0Bb",
        "colab": {}
      },
      "source": [
        "img_tensor = tf.image.decode_image(img_raw)\n",
        "\n",
        "print(img_tensor.shape)\n",
        "print(img_tensor.dtype)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3k-Of2Tfmbeq"
      },
      "source": [
        "根据你的模型调整其大小："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "XFpz-3_vlJgp",
        "colab": {}
      },
      "source": [
        "img_final = tf.image.resize(img_tensor, [192, 192])\n",
        "img_final = img_final/255.0\n",
        "print(img_final.shape)\n",
        "print(img_final.numpy().min())\n",
        "print(img_final.numpy().max())\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "aCsAa4Psl4AQ"
      },
      "source": [
        "将这些包装在一个简单的函数里，以备后用。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "HmUiZJNU73vA",
        "colab": {}
      },
      "source": [
        "def preprocess_image(image):\n",
        "  image = tf.image.decode_jpeg(image, channels=3)\n",
        "  image = tf.image.resize(image, [192, 192])\n",
        "  image /= 255.0  # normalize to [0,1] range\n",
        "\n",
        "  return image"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "einETrJnO-em",
        "colab": {}
      },
      "source": [
        "def load_and_preprocess_image(path):\n",
        "  image = tf.io.read_file(path)\n",
        "  return preprocess_image(image)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "3brWQcdtz78y",
        "colab": {}
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "image_path = all_image_paths[0]\n",
        "label = all_image_labels[0]\n",
        "\n",
        "plt.imshow(load_and_preprocess_image(img_path))\n",
        "plt.grid(False)\n",
        "plt.xlabel(caption_image(img_path))\n",
        "plt.title(label_names[label].title())\n",
        "print()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "n2TCr1TQ8pA3"
      },
      "source": [
        "## 构建一个 `tf.data.Dataset`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6H9Z5Mq63nSH"
      },
      "source": [
        "### 一个图片数据集"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GN-s04s-6Luq"
      },
      "source": [
        "构建 `tf.data.Dataset` 最简单的方法就是使用 `from_tensor_slices` 方法。\n",
        "\n",
        "将字符串数组切片，得到一个字符串数据集："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "6oRPG3Jz3ie_",
        "colab": {}
      },
      "source": [
        "path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "uML4JeMmIAvO"
      },
      "source": [
        "`shapes（维数）` 和 `types（类型）` 描述数据集里每个数据项的内容。在这里是一组标量二进制字符串。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "mIsNflFbIK34",
        "colab": {}
      },
      "source": [
        "print(path_ds)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZjyGcM8OwBJ2"
      },
      "source": [
        "现在创建一个新的数据集，通过在路径数据集上映射 `preprocess_image` 来动态加载和格式化图片。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "D1iba6f4khu-",
        "colab": {}
      },
      "source": [
        "image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "JLUPs2a-lEEJ",
        "colab": {}
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "plt.figure(figsize=(8,8))\n",
        "for n, image in enumerate(image_ds.take(4)):\n",
        "  plt.subplot(2,2,n+1)\n",
        "  plt.imshow(image)\n",
        "  plt.grid(False)\n",
        "  plt.xticks([])\n",
        "  plt.yticks([])\n",
        "  plt.xlabel(caption_image(all_image_paths[n]))\n",
        "  plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "P6FNqPbxkbdx"
      },
      "source": [
        "### 一个`(图片, 标签)`对数据集"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YgvrWLKG67-x"
      },
      "source": [
        "使用同样的 `from_tensor_slices` 方法你可以创建一个标签数据集："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "AgBsAiV06udj",
        "colab": {}
      },
      "source": [
        "label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "HEsk5nN0vyeX",
        "colab": {}
      },
      "source": [
        "for label in label_ds.take(10):\n",
        "  print(label_names[label.numpy()])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jHjgrEeTxyYz"
      },
      "source": [
        "由于这些数据集顺序相同，你可以将他们打包在一起得到一个`(图片, 标签)`对数据集："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "AOEWNMdQwsbN",
        "colab": {}
      },
      "source": [
        "image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yA2F09SJLMuM"
      },
      "source": [
        "这个新数据集的 `shapes（维数）` 和 `types（类型）` 也是维数和类型的元组，用来描述每个字段："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "DuVYNinrLL-N",
        "colab": {}
      },
      "source": [
        "print(image_label_ds)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2WYMikoPWOQX"
      },
      "source": [
        "注意：当你拥有形似 `all_image_labels` 和 `all_image_paths` 的数组，`tf.data.dataset.Dataset.zip` 的替代方法是将这对数组切片。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "HOFwZI-2WhzV",
        "colab": {}
      },
      "source": [
        "ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))\n",
        "\n",
        "# 元组被解压缩到映射函数的位置参数中\n",
        "def load_and_preprocess_from_path_label(path, label):\n",
        "  return load_and_preprocess_image(path), label\n",
        "\n",
        "image_label_ds = ds.map(load_and_preprocess_from_path_label)\n",
        "image_label_ds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vYGCgJuR_9Qp"
      },
      "source": [
        "### 训练的基本方法"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wwZavzgsIytz"
      },
      "source": [
        "要使用此数据集训练模型，你将会想要数据：\n",
        "\n",
        "* 被充分打乱。\n",
        "* 被分割为 batch。\n",
        "* 永远重复。\n",
        "* 尽快提供 batch。\n",
        "\n",
        "使用 `tf.data` api 可以轻松添加这些功能。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "uZmZJx8ePw_5",
        "colab": {}
      },
      "source": [
        "BATCH_SIZE = 32\n",
        "\n",
        "# 设置一个和数据集大小一致的 shuffle buffer size（随机缓冲区大小）以保证数据\n",
        "# 被充分打乱。\n",
        "ds = image_label_ds.shuffle(buffer_size=image_count)\n",
        "ds = ds.repeat()\n",
        "ds = ds.batch(BATCH_SIZE)\n",
        "# 当模型在训练的时候，`prefetch` 使数据集在后台取得 batch。\n",
        "ds = ds.prefetch(buffer_size=AUTOTUNE)\n",
        "ds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6JsM-xHiFCuW"
      },
      "source": [
        "这里有一些注意事项：\n",
        "\n",
        "1. 顺序很重要。\n",
        "\n",
        "  * 在 `.repeat` 之后 `.shuffle`，会在 epoch 之间打乱数据（当有些数据出现两次的时候，其他数据还没有出现过）。\n",
        "  \n",
        "  * 在 `.batch` 之后 `.shuffle`，会打乱 batch 的顺序，但是不会在 batch 之间打乱数据。\n",
        "\n",
        "1. 你在完全打乱中使用和数据集大小一样的 `buffer_size（缓冲区大小）`。较大的缓冲区大小提供更好的随机化，但使用更多的内存，直到超过数据集大小。\n",
        "\n",
        "1. 在从随机缓冲区中拉取任何元素前，要先填满它。所以当你的 `Dataset（数据集）`启动的时候一个大的 `buffer_size（缓冲区大小）`可能会引起延迟。\n",
        "\n",
        "1. 在随机缓冲区完全为空之前，被打乱的数据集不会报告数据集的结尾。`Dataset（数据集）`由 `.repeat` 重新启动，导致需要再次等待随机缓冲区被填满。\n",
        "\n",
        "\n",
        "最后一点可以通过使用 `tf.data.Dataset.apply` 方法和融合过的 `tf.data.experimental.shuffle_and_repeat` 函数来解决:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Ocr6PybXNDoO",
        "colab": {}
      },
      "source": [
        "ds = image_label_ds.apply(\n",
        "  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))\n",
        "ds = ds.batch(BATCH_SIZE)\n",
        "ds = ds.prefetch(buffer_size=AUTOTUNE)\n",
        "ds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GBBZMSuAmQVL"
      },
      "source": [
        "### 传递数据集至模型\n",
        "\n",
        "从 `tf.keras.applications` 取得 MobileNet v2 副本。\n",
        "\n",
        "该模型副本会被用于一个简单的迁移学习例子。\n",
        "\n",
        "设置 MobileNet 的权重为不可训练："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "KbJrXn9omO_g",
        "colab": {}
      },
      "source": [
        "mobile_net = tf.keras.applications.MobileNetV2(input_shape=(192, 192, 3), include_top=False)\n",
        "mobile_net.trainable=False"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Y7NVWiLF3Vbf"
      },
      "source": [
        "该模型期望它的输出被标准化至 `[-1,1]` 范围内：\n",
        "\n",
        "```\n",
        "help(keras_applications.mobilenet_v2.preprocess_input)\n",
        "```\n",
        "\n",
        "<pre>\n",
        "……\n",
        "该函数使用“Inception”预处理，将\n",
        "RGB 值从 [0, 255] 转化为 [-1, 1]\n",
        "……\n",
        "</pre>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CboYya6LmdQI"
      },
      "source": [
        "在你将输出传递给 MobilNet 模型之前，你需要将其范围从 `[0,1]` 转化为 `[-1,1]`："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "SNOkHUGv3FYq",
        "colab": {}
      },
      "source": [
        "def change_range(image,label):\n",
        "  return 2*image-1, label\n",
        "\n",
        "keras_ds = ds.map(change_range)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "QDzZ3Nye5Rpv"
      },
      "source": [
        "MobileNet 为每张图片的特征返回一个 `6x6` 的空间网格。\n",
        "\n",
        "传递一个 batch 的图片给它，查看结果："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "OzAhGkEK6WuE",
        "colab": {}
      },
      "source": [
        "# 数据集可能需要几秒来启动，因为要填满其随机缓冲区。\n",
        "image_batch, label_batch = next(iter(keras_ds))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "LcFdiWpO5WbV",
        "colab": {}
      },
      "source": [
        "feature_map_batch = mobile_net(image_batch)\n",
        "print(feature_map_batch.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vrbjEvaC5XmU"
      },
      "source": [
        "构建一个包装了 MobileNet 的模型并在 `tf.keras.layers.Dense` 输出层之前使用 `tf.keras.layers.GlobalAveragePooling2D` 来平均那些空间向量："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "X0ooIU9fNjPJ",
        "colab": {}
      },
      "source": [
        "model = tf.keras.Sequential([\n",
        "  mobile_net,\n",
        "  tf.keras.layers.GlobalAveragePooling2D(),\n",
        "  tf.keras.layers.Dense(len(label_names), activation = 'softmax')])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "foQYUJs97V4V"
      },
      "source": [
        "现在它产出符合预期 shape(维数)的输出："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "1nwYxvpj7ZEf",
        "colab": {}
      },
      "source": [
        "logit_batch = model(image_batch).numpy()\n",
        "\n",
        "print(\"min logit:\", logit_batch.min())\n",
        "print(\"max logit:\", logit_batch.max())\n",
        "print()\n",
        "\n",
        "print(\"Shape:\", logit_batch.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pFc4I_J2nNOJ"
      },
      "source": [
        "编译模型以描述训练过程："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ZWGqLEWYRNvv",
        "colab": {}
      },
      "source": [
        "model.compile(optimizer=tf.keras.optimizers.Adam(),\n",
        "              loss='sparse_categorical_crossentropy',\n",
        "              metrics=[\"accuracy\"])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tF1mO6haBOSd"
      },
      "source": [
        "此处有两个可训练的变量 —— Dense 层中的 `weights（权重）` 和 `bias（偏差）`："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "pPQ5yqyKBJMm",
        "colab": {}
      },
      "source": [
        "len(model.trainable_variables)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "kug5Wg66UJjl",
        "colab": {}
      },
      "source": [
        "model.summary()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "f_glpYZ-nYC_"
      },
      "source": [
        "你已经准备好来训练模型了。\n",
        "\n",
        "注意，出于演示目的每一个 epoch 中你将只运行 3 step，但一般来说在传递给 `model.fit()` 之前你会指定 step 的真实数量，如下所示："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "AnXPRNWoTypI",
        "colab": {}
      },
      "source": [
        "steps_per_epoch=tf.math.ceil(len(all_image_paths)/BATCH_SIZE).numpy()\n",
        "steps_per_epoch"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "q_8sabaaSGAp",
        "colab": {}
      },
      "source": [
        "model.fit(ds, epochs=1, steps_per_epoch=3)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "UMVnoBcG_NlQ"
      },
      "source": [
        "## 性能\n",
        "\n",
        "注意：这部分只是展示一些可能帮助提升性能的简单技巧。深入指南，请看：[输入 pipeline（管道）的性能](https://tensorflow.google.cn/guide/performance/datasets)。\n",
        "\n",
        "上面使用的简单 pipeline（管道）在每个 epoch 中单独读取每个文件。在本地使用 CPU 训练时这个方法是可行的，但是可能不足以进行 GPU 训练并且完全不适合任何形式的分布式训练。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "oNmQqgGhLWie"
      },
      "source": [
        "要研究这点，首先构建一个简单的函数来检查数据集的性能："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_gFVe1rp_MYr",
        "colab": {}
      },
      "source": [
        "import time\n",
        "default_timeit_steps = 2*steps_per_epoch+1\n",
        "\n",
        "def timeit(ds, steps=default_timeit_steps):\n",
        "  overall_start = time.time()\n",
        "  # 在开始计时之前\n",
        "  # 取得单个 batch 来填充 pipeline（管道）（填充随机缓冲区）\n",
        "  it = iter(ds.take(steps+1))\n",
        "  next(it)\n",
        "\n",
        "  start = time.time()\n",
        "  for i,(images,labels) in enumerate(it):\n",
        "    if i%10 == 0:\n",
        "      print('.',end='')\n",
        "  print()\n",
        "  end = time.time()\n",
        "\n",
        "  duration = end-start\n",
        "  print(\"{} batches: {} s\".format(steps, duration))\n",
        "  print(\"{:0.5f} Images/s\".format(BATCH_SIZE*steps/duration))\n",
        "  print(\"Total time: {}s\".format(end-overall_start))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TYiOr4vdLcNX"
      },
      "source": [
        "当前数据集的性能是："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ZDxLwMJOReVe",
        "colab": {}
      },
      "source": [
        "ds = image_label_ds.apply(\n",
        "  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))\n",
        "ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)\n",
        "ds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "IjouTJadRxyp",
        "colab": {}
      },
      "source": [
        "timeit(ds)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HsLlXMO7EWBR"
      },
      "source": [
        "### 缓存"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lV1NOn2zE2lR"
      },
      "source": [
        "使用 `tf.data.Dataset.cache` 在 epoch 之间轻松缓存计算结果。这是非常高效的，特别是当内存能容纳全部数据时。\n",
        "\n",
        "在被预处理之后（解码和调整大小），图片在此被缓存了："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "qj_U09xpDvOg",
        "colab": {}
      },
      "source": [
        "ds = image_label_ds.cache()\n",
        "ds = ds.apply(\n",
        "  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))\n",
        "ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)\n",
        "ds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "rdxpvQ7VEo3y",
        "colab": {}
      },
      "source": [
        "timeit(ds)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "usIv7MqqZQps"
      },
      "source": [
        "使用内存缓存的一个缺点是必须在每次运行时重建缓存，这使得每次启动数据集时有相同的启动延迟："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "eKX6ergKb_xd",
        "colab": {}
      },
      "source": [
        "timeit(ds)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jUzpG4lYNkN-"
      },
      "source": [
        "如果内存不够容纳数据，使用一个缓存文件："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "vIvF8K4GMq0g",
        "colab": {}
      },
      "source": [
        "ds = image_label_ds.cache(filename='./cache.tf-data')\n",
        "ds = ds.apply(\n",
        "  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))\n",
        "ds = ds.batch(BATCH_SIZE).prefetch(1)\n",
        "ds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "eTIj6IOmM4yA",
        "colab": {}
      },
      "source": [
        "timeit(ds)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qqo3dyB0Z4t2"
      },
      "source": [
        "这个缓存文件也有可快速重启数据集而无需重建缓存的优点。注意第二次快了多少："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "hZhVdR8MbaUj",
        "colab": {}
      },
      "source": [
        "timeit(ds)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WqOVlf8tFrDU"
      },
      "source": [
        "### TFRecord 文件"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "y1llOTwWFzmR"
      },
      "source": [
        "#### 原始图片数据\n",
        "\n",
        "TFRecord 文件是一种用来存储一串二进制 blob 的简单格式。通过将多个示例打包进同一个文件内，TensorFlow 能够一次性读取多个示例，当使用一个远程存储服务，如 GCS 时，这对性能来说尤其重要。\n",
        "\n",
        "首先，从原始图片数据中构建出一个 TFRecord 文件："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "EqtARqKuHQLu",
        "colab": {}
      },
      "source": [
        "image_ds = tf.data.Dataset.from_tensor_slices(all_image_paths).map(tf.io.read_file)\n",
        "tfrec = tf.data.experimental.TFRecordWriter('images.tfrec')\n",
        "tfrec.write(image_ds)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "flR2GXWFKcO1"
      },
      "source": [
        "接着，构建一个从 TFRecord 文件读取的数据集，并使用你之前定义的 `preprocess_image` 函数对图像进行解码/重新格式化："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "j9PVUL2SFufn",
        "colab": {}
      },
      "source": [
        "image_ds = tf.data.TFRecordDataset('images.tfrec').map(preprocess_image)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cRp1eZDRKzyN"
      },
      "source": [
        "压缩该数据集和你之前定义的标签数据集以得到期望的 `(图片,标签)` 对："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "7XI_nDU2KuhS",
        "colab": {}
      },
      "source": [
        "ds = tf.data.Dataset.zip((image_ds, label_ds))\n",
        "ds = ds.apply(\n",
        "  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))\n",
        "ds=ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)\n",
        "ds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "3ReSapoPK22E",
        "colab": {}
      },
      "source": [
        "timeit(ds)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wb7VyoKNOMms"
      },
      "source": [
        "这比 `缓存` 版本慢，因为你还没有缓存预处理。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NF9W-CTKkM-f"
      },
      "source": [
        "#### 序列化的 Tensor（张量）"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "J9HzljSPkxt0"
      },
      "source": [
        "要为 TFRecord 文件省去一些预处理过程，首先像之前一样制作一个处理过的图片数据集："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "OzS0Azukkjyw",
        "colab": {}
      },
      "source": [
        "paths_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)\n",
        "image_ds = paths_ds.map(load_and_preprocess_image)\n",
        "image_ds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "onWOwLpYlzJQ"
      },
      "source": [
        "现在你有一个 tensor（张量）数据集，而不是一个 `.jpeg` 字符串数据集。\n",
        "\n",
        "要将此序列化至一个 TFRecord 文件你首先将该 tensor（张量）数据集转化为一个字符串数据集："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "xxZSwnRllyf0",
        "colab": {}
      },
      "source": [
        "ds = image_ds.map(tf.io.serialize_tensor)\n",
        "ds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "w9N6hJWAkKPC",
        "colab": {}
      },
      "source": [
        "tfrec = tf.data.experimental.TFRecordWriter('images.tfrec')\n",
        "tfrec.write(ds)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OlFc9dJSmcx0"
      },
      "source": [
        "有了被缓存的预处理，就能从 TFrecord 文件高效地加载数据——只需记得在使用它之前反序列化："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "BsqFyTBFmSCZ",
        "colab": {}
      },
      "source": [
        "ds = tf.data.TFRecordDataset('images.tfrec')\n",
        "\n",
        "def parse(x):\n",
        "  result = tf.io.parse_tensor(x, out_type=tf.float32)\n",
        "  result = tf.reshape(result, [192, 192, 3])\n",
        "  return result\n",
        "\n",
        "ds = ds.map(parse, num_parallel_calls=AUTOTUNE)\n",
        "ds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OPs_sLV9pQg5"
      },
      "source": [
        "现在，像之前一样添加标签和进行相同的标准操作："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "XYxBwaLYnGop",
        "colab": {}
      },
      "source": [
        "ds = tf.data.Dataset.zip((ds, label_ds))\n",
        "ds = ds.apply(\n",
        "  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))\n",
        "ds=ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)\n",
        "ds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "W8X6RmGan1-P",
        "colab": {}
      },
      "source": [
        "timeit(ds)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
